package com.omega.engine.nn.layer.lpips.gpu;

import static jcuda.driver.JCudaDriver.cuLaunchKernel;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAModules;

import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.runtime.cudaError;

public class LPIPSKernel extends BaseKernel{
	
	private CUfunction lpip_l2_function;
	
	private CUfunction lpip_l2_backward_function;
	
	private CUfunction scaling_function;
	
	private CUfunction scaling_backwad_function;
	
	private int CAFFE_CUDA_NUM_THREADS = 1024;
	
	private Pointer forwardKernelParameters;

	
	public LPIPSKernel() {
		init();
	}
	
	public void initFunction() {
		
		try {
			
			if(lpip_l2_function == null) {
				
				lpip_l2_function = CUDAModules.getLocalFunctionByModule("lpipsKernel.cu", "lpip_l2_kernel");
				
			}
			
			if(lpip_l2_backward_function == null) {
				
				lpip_l2_backward_function = CUDAModules.getLocalFunctionByModule("lpipsKernel.cu", "lpip_l2_backward_kernel");
				
			}
			
			if(scaling_function == null) {
				
				scaling_function = CUDAModules.getLocalFunctionByModule("lpipsKernel.cu", "scaling_kernel");
				
			}
			
			if(scaling_backwad_function == null) {
				
				scaling_backwad_function = CUDAModules.getLocalFunctionByModule("lpipsKernel.cu", "scaling_back_kernel");
				
			}
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
	public void init() {
		/**
		 * 初始化cuda函数
		 */
		initFunction();
	}
	
	
	public int CAFFE_GET_BLOCKS(int N) {
	    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
	}
	
	public void lpip_l2(Tensor x1,Tensor x2,Tensor output) {
		
		try {


	        /**
	         * 设置入参
	         * float *x1,float *x2, float *out, int N
	         */
			forwardKernelParameters = Pointer.to(
					Pointer.to(x1.getGpuData()),
					Pointer.to(x2.getGpuData()),
					Pointer.to(output.getGpuData()),
					 Pointer.to(new int[]{x1.getDataLength()})
	            );
			
			cuLaunchKernel(lpip_l2_function,
		            this.CAFFE_GET_BLOCKS(x1.getDataLength()),  1, 1,      // Grid dimension
		            CAFFE_CUDA_NUM_THREADS, 1, 1,      // Block dimension
		            0, null,               // Shared memory size and stream
		            forwardKernelParameters, null // Kernel- and extra parameters
		        );
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
	public void lpip_l2_backward(Tensor delta,Tensor x1,Tensor x2,Tensor diff) {
		
		try {


	        /**
	         * 设置入参
	         * float *delta,float *x1,float *x2, float *diff, int N
	         */
			forwardKernelParameters = Pointer.to(
					Pointer.to(delta.getGpuData()),
					Pointer.to(x1.getGpuData()),
					Pointer.to(x2.getGpuData()),
					Pointer.to(diff.getGpuData()),
					 Pointer.to(new int[]{delta.getDataLength()})
	            );
			
			cuLaunchKernel(lpip_l2_backward_function,
		            this.CAFFE_GET_BLOCKS(delta.getDataLength()),  1, 1,      // Grid dimension
		            CAFFE_CUDA_NUM_THREADS, 1, 1,      // Block dimension
		            0, null,               // Shared memory size and stream
		            forwardKernelParameters, null // Kernel- and extra parameters
		        );
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
	public void scaling(Tensor x,Tensor shift,Tensor scale,Tensor output) {
		
		try {

	        /**
	         * 设置入参
	         * int N,float *x,float *shift, float *scale, float *out, int C,int HW
	         */
			forwardKernelParameters = Pointer.to(
					Pointer.to(new int[]{x.getDataLength()}),
					Pointer.to(x.getGpuData()),
					Pointer.to(shift.getGpuData()),
					Pointer.to(scale.getGpuData()),
					Pointer.to(output.getGpuData()),
					Pointer.to(new int[]{x.channel}),
					Pointer.to(new int[]{x.height * x.width})
	            );
			
			cuLaunchKernel(scaling_function,
		            this.CAFFE_GET_BLOCKS(x.getDataLength()),  1, 1,      // Grid dimension
		            CAFFE_CUDA_NUM_THREADS, 1, 1,      // Block dimension
		            0, null,               // Shared memory size and stream
		            forwardKernelParameters, null // Kernel- and extra parameters
		        );
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
	public void scaling_backwad(Tensor dy,Tensor scale,Tensor dx) {
		
		try {

			 /**
	         * 设置入参
	         * int N,float *dy, float *scale, float *dx, int C,int HW
	         */
			forwardKernelParameters = Pointer.to(
					Pointer.to(new int[]{dy.getDataLength()}),
					Pointer.to(dy.getGpuData()),
					Pointer.to(scale.getGpuData()),
					Pointer.to(dx.getGpuData()),
					Pointer.to(new int[]{dy.channel}),
					Pointer.to(new int[]{dy.height * dy.width})
	            );
			
			cuLaunchKernel(scaling_backwad_function,
		            this.CAFFE_GET_BLOCKS(dy.getDataLength()),  1, 1,      // Grid dimension
		            CAFFE_CUDA_NUM_THREADS, 1, 1,      // Block dimension
		            0, null,               // Shared memory size and stream
		            forwardKernelParameters, null // Kernel- and extra parameters
		        );
			
		} catch (Exception e) {
			// TODO: handle exception
			e.printStackTrace();
		}
		
	}
	
	public void checkCUDA(int code) {
		if(code != cudaError.cudaSuccess) {
			System.err.println("Error code "+code+":"+cudaError.stringFor(code));
		}
	}
	
    public static void main(String args[]){	
    	
    	int N = 2;
    	int C = 3;
    	int H = 2;
    	int W = 2;
    	
    	float[] data = RandomUtils.order(N * C * H * W, 1f, 1f);
    	
    	Tensor input = new Tensor(N, C, H, W, data, true);
    	
    	Tensor output = new Tensor(N, C, H, W, true);
    	
    	Tensor shift = new Tensor(1, 1, 1, 3, new float[] {-0.03f, -0.088f, -0.188f}, true);
		
    	Tensor scale = new Tensor(1, 1, 1, 3, new float[] {0.458f, 0.448f, 0.45f}, true);
    	
    	LPIPSKernel kernel = new LPIPSKernel();
    	
    	kernel.scaling(input, shift, scale, output);
    	
    	output.showDM();
    	
    	kernel.scaling_backwad(input, scale, output);
    	
    	output.showDM();
    	
    }

}
